# Load all necessary libraries and dependencies
import os
import cv2
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tqdm import tqdm
# Load Tensorflow layer types
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Cropping2D
from tensorflow.keras.models import Model
# Create face detector object from the Open Computer Vision library
_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
face_cascade = cv2.CascadeClassifier(_cascade_path)
# CascadeClassifier does not raise when the XML file is missing or unreadable;
# it builds an empty classifier and the failure would only surface later, when
# detection is attempted. Fail fast with a clear message instead.
if face_cascade.empty():
    raise OSError(f"Could not load Haar cascade from {_cascade_path!r}")
# Load the complete CelebA face dataset from the local Kaggle copy; with
# download=False nothing is fetched, so the TFDS files must be under data_dir.
# File order is kept fixed (shuffle_files=False) so that the quarters sliced
# later are reproducible.
train_celeb, test_celeb = tfds.load(
    "celeb_a",
    split=["train", "test"],
    shuffle_files=False,
    data_dir='/kaggle/input/tfds-celeba-dataset',
    download=False,
)
# Function to apply Gaussian blur to faces
def blur_faces(image, face, strength):
    """Gaussian-blur every face box of ``image`` in place and return ``image``.

    Args:
        image: image array; the face regions are overwritten in place.
        face: iterable of ``(x, y, w, h)`` boxes (as produced by
            ``detectMultiScale``).
        strength: Gaussian kernel size, used for both axes. OpenCV requires
            it to be positive and odd (or zero, to derive it from sigma).

    Returns:
        The same ``image`` object that was passed in.
    """
    for box in face:
        left, top, width, height = box
        rows = slice(top, top + height)
        cols = slice(left, left + width)
        # sigmaX is fixed at 30; only the kernel size depends on `strength`
        image[rows, cols] = cv2.GaussianBlur(image[rows, cols], (strength, strength), 30)
    return image
# Function to pixelate faces
def pixelate_faces(image, face, strength):
    """Pixelate every face box of ``image`` in place and return ``image``.

    Each region is shrunk by a factor of ``strength`` and scaled back to its
    original size with nearest-neighbour interpolation. This produces blocks
    roughly ``strength`` pixels wide.

    Args:
        image: image array; the face regions are overwritten in place.
        face: iterable of ``(x, y, w, h)`` boxes (as produced by
            ``detectMultiScale``).
        strength: positive integer downscale factor.

    Returns:
        The same ``image`` object that was passed in.
    """
    for (x, y, w, h) in face:
        face_roi = image[y:y+h, x:x+w]
        # Clamp the reduced size to at least 1 px. cv2.resize rejects a zero
        # target size, which happens when a box is narrower than `strength`.
        small_size = (max(1, w // strength), max(1, h // strength))
        face_roi = cv2.resize(face_roi, small_size)
        face_roi = cv2.resize(face_roi, (w, h), interpolation = cv2.INTER_NEAREST)
        # Attach pixelated face
        image[y:y+h, x:x+w] = face_roi
    return image
# Function to apply motion blur faces
def motion_blur_faces(image, face, strength):
    """Apply a horizontal motion blur to every face box of ``image`` in place.

    The kernel is a ``strength`` x ``strength`` matrix. It is zero except for
    its middle row, which is filled with ``1 / strength``, so the weights sum
    to 1 and each pixel is averaged with its horizontal neighbours.

    Args:
        image: image array; the face regions are overwritten in place.
        face: iterable of ``(x, y, w, h)`` boxes (as produced by
            ``detectMultiScale``).
        strength: positive integer kernel length in pixels.

    Returns:
        The same ``image`` object that was passed in.
    """
    # The kernel depends only on `strength`, so build it once rather than per face
    kernel_motion_blur = np.zeros((strength, strength))
    kernel_motion_blur[(strength - 1) // 2, :] = 1.0 / strength
    for (x, y, w, h) in face:
        face_roi = image[y:y+h, x:x+w]
        # ddepth = -1 keeps the output depth equal to the input's
        image[y:y+h, x:x+w] = cv2.filter2D(face_roi, -1, kernel_motion_blur)
    return image
# Create a dataset portion
def create_dataset(dataset, image_type, quarter, test_fold):
    """Build a list of face-degraded, normalized CelebA crops.

    For every example:
      1. Drop it if its 'Blurry' or 'Eyeglasses' attribute is set.
      2. Crop the rows to a band of height ``2 * (width // 2)`` centred
         vertically (clamped to the image).
      3. Run the Haar face detector and drop the example if no face is found.
      4. Optionally degrade the first detected face.
      5. Scale the pixels to [0, 1] as float16.

    Args:
        dataset: tf.data.Dataset of CelebA example dicts as returned by
            ``tfds.load``; each must have 'image' and 'attributes' entries.
        image_type: 'Gauss', 'Pixel' or 'Motion' selects the face degradation;
            any other value (e.g. 'Original') keeps the crop unmodified.
        quarter: 1-based index (1-4) of the quarter of ``dataset`` to process
            when ``test_fold`` is false; ignored otherwise.
        test_fold: if true, process the whole dataset instead of one quarter.

    Returns:
        List of float16 arrays with values in [0, 1], one per kept example.

    Raises:
        ValueError: if ``test_fold`` is false and ``quarter`` is not 1-4.
    """
    # Degradation function and strength for each supported image_type
    face_filters = {
        'Gauss': (blur_faces, 5),
        'Pixel': (pixelate_faces, 4),
        'Motion': (motion_blur_faces, 15),
    }
    total_images = len(dataset)
    if not test_fold:
        if quarter not in (1, 2, 3, 4):
            raise ValueError(f"quarter must be 1, 2, 3 or 4, got {quarter!r}")
        # Jump straight to the requested quarter instead of iterating over (and
        # decoding) every earlier example only to discard it. As before, the
        # last `total_images % 4` examples belong to no quarter.
        quarter_size = total_images // 4
        dataset = dataset.skip(quarter_size * (quarter - 1)).take(quarter_size)
        total_images = quarter_size
    img_lst = []
    for image_dict in tqdm(tfds.as_numpy(dataset), total = total_images, desc = "Building Dataset"):
        # Filter out lesser quality images
        attributes = image_dict['attributes']
        if attributes['Blurry'] or attributes['Eyeglasses']:
            continue
        image = image_dict['image']
        # Crop a band of rows around the vertical center of the image
        center_y = image.shape[0] // 2
        center_x = image.shape[1] // 2
        top = max(0, center_y - center_x)
        bottom = min(image.shape[0], center_y + center_x)
        cropped_img = image[top:bottom, :, :]
        # Detect the face. TFDS decodes images as RGB, so convert from RGB
        # (not BGR) to get the correct grayscale weights.
        gray = cv2.cvtColor(cropped_img, cv2.COLOR_RGB2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor = 1.1, minNeighbors = 5, minSize = (30, 30))
        if len(faces) == 0:
            continue
        # Only keep the first face if more are found
        faces = faces[:1]
        # Apply filter ('Original' and other unknown types are left unchanged)
        if image_type in face_filters:
            apply_filter, strength = face_filters[image_type]
            cropped_img = apply_filter(cropped_img, faces, strength)
        # Normalize pixels to [0, 1]; half precision keeps the list small
        cropped_img = cropped_img.astype(np.float16)
        cropped_img /= 255.0
        img_lst.append(cropped_img)
    return img_lst
# Convert image dataset to TensorFlow dataset
def to_tf_dataset(images):
    """Wrap a list of equally-shaped image arrays as a ``tf.data.Dataset``.

    The list is stacked into one array and sliced along its first axis. Each
    element of the result is a dict ``{'image': tensor}``, which is the layout
    ``pair_images`` reads back after the dataset is saved and reloaded.
    """
    def as_record(single_image):
        return {'image': single_image}

    stacked = np.asarray(images)
    return tf.data.Dataset.from_tensor_slices(stacked).map(as_record)
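# The cell below was run manually once per dataset variant: each filter type ('Original', 'Gauss', 'Pixel', 'Motion') for each training quarter (1-4), plus the test split. This produced and saved every dataset in the required format; only the first run is shown.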
# Build quarter 1 of the unmodified ("Original") training images and persist it
# as a GZIP-compressed tf.data snapshot. The intermediates are never bound to a
# name, so they can be freed as soon as the save completes.
to_tf_dataset(
    create_dataset(train_celeb, "Original", 1, False)
).save('/kaggle/working/target1', compression = 'GZIP')
# Output: Building Dataset: 100%|██████████| 40692/40692 [11:24<00:00, 59.43it/s]
# Root folder holding the datasets saved by the preprocessing cells
_CELEBA_FACES_ROOT = '/kaggle/input/celeba-faces-dataset'


def _load_saved(name):
    """Load the GZIP-compressed tf.data snapshot stored at <root>/<name>/<name>."""
    return tf.data.Dataset.load(f'{_CELEBA_FACES_ROOT}/{name}/{name}', compression = "GZIP")


def _load_quarters(prefix):
    """Load snapshots <prefix>1 .. <prefix>4 and concatenate them in that order."""
    merged = _load_saved(f'{prefix}1')
    for part in range(2, 5):
        merged = merged.concatenate(_load_saved(f'{prefix}{part}'))
    return merged


# Training data: each variant was saved as four quarters, concatenated here in
# the same order for every variant so that zipping keeps images aligned.
target_df = _load_quarters('target')
blurry_df = _load_quarters('blurry')
pixels_df = _load_quarters('pixels')
motion_df = _load_quarters('motion')
# Load in validation data
val_target = _load_saved('target_test')
val_blurry = _load_saved('blurry_test')
val_pixels = _load_saved('pixels_test')
val_motion = _load_saved('motion_test')
# Merge data for training: (degraded example, original example)
combined_blurry_df = tf.data.Dataset.zip((blurry_df, target_df))
combined_pixels_df = tf.data.Dataset.zip((pixels_df, target_df))
combined_motion_df = tf.data.Dataset.zip((motion_df, target_df))
# Merge data for validation
val_blurry_df = tf.data.Dataset.zip((val_blurry, val_target))
val_pixels_df = tf.data.Dataset.zip((val_pixels, val_target))
val_motion_df = tf.data.Dataset.zip((val_motion, val_target))
# Define pipelining function
def pair_images(element_a, element_b):
    """Turn one zipped pair of examples into an ``(input, target)`` tuple.

    ``element_a`` is the degraded (blurred, pixelated or motion-blurred)
    example and ``element_b`` is its original counterpart. Both are dicts, and
    only their ``'image'`` entries are kept.
    """
    return element_a['image'], element_b['image']
# Turn every zipped (degraded, original) dict pair into an (input, target)
# tuple, first for the training pipelines and then for the validation ones.
blurry_pipe, pixels_pipe, motion_pipe = (
    zipped.map(pair_images)
    for zipped in (combined_blurry_df, combined_pixels_df, combined_motion_df)
)
blurry_validation, pixels_validation, motion_validation = (
    zipped.map(pair_images)
    for zipped in (val_blurry_df, val_pixels_df, val_motion_df)
)
# Define the titles for each row
titles = ['Original Image', 'Gaussian Blur', 'Pixelation', 'Motion Blur']
# Select 4 random images
selected_indices = random.sample(range(200), 4)
# Plot a 4x4 grid: one column per sampled index, one row per image variant
plt.figure(figsize = (9, 9))
for i, idx in enumerate(selected_indices):
    # Fetch each pipeline's element once. blurry_pipe supplies both the
    # Gaussian-blurred input [0] and the original target [1].
    blurred, original = (t.numpy() for t in next(iter(blurry_pipe.skip(idx).take(1))))
    pixelated = next(iter(pixels_pipe.skip(idx).take(1)))[0].numpy()
    motion = next(iter(motion_pipe.skip(idx).take(1)))[0].numpy()
    for row, image in enumerate((original, blurred, pixelated, motion)):
        plt.subplot(4, 4, row * 4 + i + 1, frameon = True)
        plt.imshow((image * 255).astype(int))
        plt.axis('off')
        if i == 0:
            # Add row title to the left of the first column
            plt.text(-0.1, 0.5, titles[row], fontsize = 10, ha = 'right', va = 'center',
                     rotation = 90, transform = plt.gca().transAxes)
plt.tight_layout()
plt.show()
def simple_autoencoder(input_shape, filters = 32):
    """Build and compile a small convolutional autoencoder.

    The model has two Conv2D + MaxPooling2D encoder stages and two Conv2D +
    UpSampling2D decoder stages. A final 3-channel sigmoid Conv2D follows, and
    its output is cropped back to the input's height and width.

    Args:
        input_shape: (height, width, channels) of the input images; height and
            width must be known integers.
        filters: number of filters in each hidden Conv2D layer. The default of
            32 reproduces the original architecture.

    Returns:
        A Keras Model compiled with the 'adam' optimizer and MSE loss.
    """
    def crop_amount(size):
        # 'same' pooling rounds odd sizes up (178 -> 89 -> 45) and upsampling
        # doubles (45 -> 90 -> 180), so the output exceeds the input by
        # 4 * ceil(ceil(size / 2) / 2) - size pixels; split that evenly.
        excess = 4 * (((size + 1) // 2 + 1) // 2) - size
        return (excess // 2, excess - excess // 2)

    inputs = Input(shape = input_shape)
    # Encoder
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(inputs)
    x = MaxPooling2D((2, 2), padding = 'same')(x)
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(x)
    encoded = MaxPooling2D((2, 2), padding = 'same')(x)
    # Decoder
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(encoded)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(x)
    x = UpSampling2D((2, 2))(x)
    decoded = Conv2D(3, (3, 3), activation = 'sigmoid', padding = 'same')(x)
    cropping = (crop_amount(input_shape[0]), crop_amount(input_shape[1]))
    cropped_decoded = Cropping2D(cropping = cropping)(decoded)
    autoencoder = Model(inputs, cropped_decoded)
    autoencoder.compile(optimizer = 'adam', loss = 'mean_squared_error')
    return autoencoder
# Define input shape
input_shape = (178, 178, 3)
# Build the autoencoder
autoencoder = simple_autoencoder(input_shape)
print("Number of Parameters:", autoencoder.count_params())
# Train the autoencoder. Keras ignores `shuffle=True` when given a tf.data
# Dataset, and the concatenated quarters arrive in a fixed order, so the
# dataset itself is shuffled. The buffer gives a local, not fully global,
# shuffle.
autoencoder.fit(blurry_pipe.shuffle(1024).batch(16).prefetch(3),
                epochs = 3,
                verbose = True,
                validation_data = blurry_validation.batch(16).prefetch(3))
# Output: Number of Parameters: 29507 Epoch 1/3 8610/8610 ━━━━━━━━━━━━━━━━━━━━ 272s 31ms/step - loss: 0.0034 Epoch 2/3 8610/8610 ━━━━━━━━━━━━━━━━━━━━ 213s 25ms/step - loss: 0.0011 Epoch 3/3 8610/8610 ━━━━━━━━━━━━━━━━━━━━ 212s 25ms/step - loss: 9.9701e-04
# Output: <keras.src.callbacks.history.History at 0x7ae2127a6860>
# Compare targets, Gaussian-blurred inputs and reconstructions on held-out
# validation images rather than on the images the model was trained on.
indices = [0, 21, 24, 30]
row_labels = ["Target Images", "Gaussian Blur Images", "Deblurred Images"]
fig, axes = plt.subplots(3, 4, figsize = (9, 7), subplot_kw = {"xticks": [], "yticks": []})
for col, index in enumerate(indices):
    # Fetch the (blurred, target) pair once and reuse it for all three rows
    blurred, target = (t.numpy() for t in next(iter(blurry_validation.skip(index).take(1))))
    reconstructed = autoencoder.predict(np.expand_dims(blurred, axis = 0))[0]
    for row, image in enumerate((target, blurred, reconstructed)):
        ax = axes[row, col]
        if col == 0:
            ax.text(-0.1, 0.5, row_labels[row], fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
        ax.imshow((image * 255).astype(int))
plt.tight_layout()
# Output: 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
# One record (dict) per evaluated architecture. This must be a list: records
# are appended below and then turned into a DataFrame with one row per model.
model_comparison = []
def build_autoencoder(input_shape, filters = 128):
    """Build and compile the wider convolutional autoencoder.

    The model has two Conv2D + MaxPooling2D encoder stages and two Conv2D +
    UpSampling2D decoder stages. A final 3-channel sigmoid Conv2D follows, and
    its output is cropped back to the input's height and width.

    Args:
        input_shape: (height, width, channels) of the input images; height and
            width must be known integers.
        filters: number of filters in each hidden Conv2D layer. The default of
            128 reproduces the original architecture.

    Returns:
        A Keras Model compiled with the 'adam' optimizer and MSE loss.
    """
    def crop_amount(size):
        # 'same' pooling rounds odd sizes up (178 -> 89 -> 45) and upsampling
        # doubles (45 -> 90 -> 180), so the output exceeds the input by
        # 4 * ceil(ceil(size / 2) / 2) - size pixels; split that evenly.
        excess = 4 * (((size + 1) // 2 + 1) // 2) - size
        return (excess // 2, excess - excess // 2)

    inputs = Input(shape = input_shape)
    # Encoder
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(inputs)
    x = MaxPooling2D((2, 2), padding = 'same')(x)
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(x)
    encoded = MaxPooling2D((2, 2), padding = 'same')(x)
    # Decoder
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(encoded)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(x)
    x = UpSampling2D((2, 2))(x)
    decoded = Conv2D(3, (3, 3), activation = 'sigmoid', padding = 'same')(x)
    cropping = (crop_amount(input_shape[0]), crop_amount(input_shape[1]))
    cropped_decoded = Cropping2D(cropping = cropping)(decoded)
    autoencoder = Model(inputs, cropped_decoded)
    autoencoder.compile(optimizer = 'adam', loss = 'mean_squared_error')
    return autoencoder
# Build the autoencoder
autoencoder = build_autoencoder(input_shape)
# Train the autoencoder. Keras ignores `shuffle=True` when given a tf.data
# Dataset, and the concatenated quarters arrive in a fixed order, so the
# dataset itself is shuffled. The buffer gives a local, not fully global,
# shuffle.
autoencoder.fit(pixels_pipe.shuffle(1024).batch(16).prefetch(3),
                epochs = 3,
                verbose = True,
                validation_data = pixels_validation.batch(16).prefetch(3))
# Output: Epoch 1/3 8610/8610 ━━━━━━━━━━━━━━━━━━━━ 806s 91ms/step - loss: 0.0029 - val_loss: 0.0012 Epoch 2/3 8610/8610 ━━━━━━━━━━━━━━━━━━━━ 764s 89ms/step - loss: 0.0012 - val_loss: 0.0011 Epoch 3/3 8610/8610 ━━━━━━━━━━━━━━━━━━━━ 762s 89ms/step - loss: 0.0011 - val_loss: 0.0010
# Output: <keras.src.callbacks.history.History at 0x78595c67bdf0>
# I iterated through six candidate model architectures, recording how each performed. The cell below scores the current architecture and appends its results to model_comparison.
# Score the current architecture on the full training pipe, then on the
# validation pipe, and record the result for the comparison chart.
train_loss, valid_loss = (
    autoencoder.evaluate(pipe.batch(64).prefetch(3))
    for pipe in (pixels_pipe, pixels_validation)
)
model_comparison.append({
    "Model Name": "128, 128",
    "Parameters": autoencoder.count_params(),
    "Train MSE": train_loss,
    "Validation MSE": valid_loss,
})
# Output: 2153/2153 ━━━━━━━━━━━━━━━━━━━━ 246s 104ms/step - loss: 0.0010 266/266 ━━━━━━━━━━━━━━━━━━━━ 38s 143ms/step - loss: 0.0010
# One row per recorded architecture
df = pd.DataFrame(model_comparison)
# Side-by-side bar charts: model size on the left, validation error on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (16, 6))
panels = (
    (ax1, 'Parameters', 'dodgerblue', 'Number of Parameters', 'Number of Parameters in Each Model'),
    (ax2, 'Validation MSE', 'darkviolet', 'Validation MSE', 'Validation MSE for Each Model'),
)
for panel_ax, column, colour, y_label, panel_title in panels:
    panel_ax.bar(df['Model Name'], df[column], color = colour, edgecolor = 'black')
    panel_ax.set_xlabel('Model')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.grid(alpha = 0.5)
plt.tight_layout()
plt.show()
# Compare targets, pixelated inputs and reconstructions on validation images
indices = [0, 21, 24, 30]
row_labels = ["Target Images", "Pixelated Images", "Depixelated Images"]
fig, axes = plt.subplots(3, 4, figsize = (9, 7), subplot_kw = {"xticks": [], "yticks": []})
for col, index in enumerate(indices):
    # Fetch the (pixelated, target) pair once and reuse it for all three rows
    pixelated, target = (t.numpy() for t in next(iter(pixels_validation.skip(index).take(1))))
    reconstructed = autoencoder.predict(np.expand_dims(pixelated, axis = 0))[0]
    for row, image in enumerate((target, pixelated, reconstructed)):
        ax = axes[row, col]
        if col == 0:
            ax.text(-0.1, 0.5, row_labels[row], fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
        ax.imshow((image * 255).astype(int))
plt.tight_layout()
# Output: 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers
# Define the generator network
def build_generator(input_shape):
    """Build the encoder-decoder generator used by the GAN.

    Two stride-2 Conv2D stages downsample and two stride-2 Conv2DTranspose
    stages upsample. The result is cropped back to the input height and width.

    The final layer uses a sigmoid so outputs lie in [0, 1]. That is the range
    of the training images (create_dataset divides pixels by 255) and of the
    real images the discriminator sees. A tanh output in [-1, 1] would let the
    discriminator separate real from fake by value range alone, and it yields
    negative pixels when outputs are displayed via ``* 255``.

    Args:
        input_shape: (height, width, channels) of the input images; height and
            width must be known integers.

    Returns:
        An uncompiled keras Sequential model.
    """
    def crop_amount(size):
        # 'same' stride-2 convs round odd sizes up (178 -> 89 -> 45) and the
        # stride-2 transposed convs double (45 -> 90 -> 180); split the excess.
        excess = 4 * (((size + 1) // 2 + 1) // 2) - size
        return (excess // 2, excess - excess // 2)

    model = models.Sequential([
        layers.Input(shape = input_shape),
        layers.Conv2D(64, (4, 4), strides = (2, 2), padding = 'same', use_bias = False),
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope = 0.2),
        layers.Conv2D(128, (4, 4), strides = (2, 2), padding = 'same', use_bias = False),
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope = 0.2),
        layers.Conv2DTranspose(64, (4, 4), strides = (2, 2), padding = 'same', use_bias = False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(3, (4, 4), strides = (2, 2), padding = 'same', activation = 'sigmoid'),
        layers.Cropping2D(cropping = (crop_amount(input_shape[0]), crop_amount(input_shape[1])))
    ])
    return model
# Define the discriminator network
def build_discriminator(input_shape):
    """Build the convolutional discriminator.

    Three stride-2 Conv2D stages with 32, 64 and 128 filters are used, each
    followed by LeakyReLU(0.2). BatchNormalization sits between the conv and
    the activation in all but the first stage. A Flatten + Dense(1) head then
    produces one unbounded score per image.
    """
    stack = [layers.Input(shape=input_shape)]
    for n_filters, use_norm in ((32, False), (64, True), (128, True)):
        stack.append(layers.Conv2D(n_filters, (4, 4), strides = (2, 2), padding='same'))
        if use_norm:
            stack.append(layers.BatchNormalization())
        stack.append(layers.LeakyReLU(negative_slope = 0.2))
    stack.extend([layers.Flatten(), layers.Dense(1)])
    return models.Sequential(stack)
# Define the generator loss function
def generator_loss(fake_output):
    """Least-squares adversarial loss for the generator.

    This pushes the discriminator's scores for generated images towards 1
    ("real"). ``losses.mean_squared_error`` only averages over the last axis,
    so for (batch, 1) scores it returns one value per sample. The result is
    reduced to a scalar mean so the loss, and its gradient, do not scale with
    the batch size.
    """
    per_sample = losses.mean_squared_error(tf.ones_like(fake_output), fake_output)
    return tf.reduce_mean(per_sample)
# Define the discriminator loss function
def discriminator_loss(real_output, fake_output):
    """Least-squares adversarial loss for the discriminator.

    Real-image scores are pushed towards 1 and generated-image scores towards
    0. Each term is reduced to a scalar mean, because
    ``losses.mean_squared_error`` alone returns one value per sample.
    """
    real_loss = tf.reduce_mean(losses.mean_squared_error(tf.ones_like(real_output), real_output))
    fake_loss = tf.reduce_mean(losses.mean_squared_error(tf.zeros_like(fake_output), fake_output))
    return real_loss + fake_loss
# Define the generator and discriminator
input_shape = (178, 178, 3)
generator = build_generator(input_shape)
discriminator = build_discriminator(input_shape)
# Define one optimizer per network: Adam keeps per-variable state, and
# train_step and the checkpoint below both need generator_optimizer.
generator_optimizer = optimizers.Adam()
discriminator_optimizer = optimizers.Adam()
checkpoint = tf.train.Checkpoint(generator_optimizer = generator_optimizer,
                                 discriminator_optimizer = discriminator_optimizer,
                                 generator = generator,
                                 discriminator = discriminator)
# One adversarial update of both networks, compiled into a graph
@tf.function
def train_step(images_blurred, images_clear):
    """Update the generator and discriminator once on a batch.

    Args:
        images_blurred: batch of degraded images fed to the generator.
        images_clear: matching batch of clean images, scored as "real".

    Returns:
        ``(gen_loss, disc_loss)`` from generator_loss / discriminator_loss.

    NOTE(review): ``images_clear`` never enters the generator's loss, so
    nothing ties a restored image to its own target; the generator is only
    asked to fool the discriminator. A content term (e.g. L1 to the target)
    is the usual addition for paired restoration.
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        restored = generator(images_blurred, training = True)
        score_real = discriminator(images_clear, training = True)
        score_fake = discriminator(restored, training = True)
        gen_loss = generator_loss(score_fake)
        disc_loss = discriminator_loss(score_real, score_fake)
    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables
    # Compute both gradient sets before applying either update
    gen_grads = g_tape.gradient(gen_loss, gen_vars)
    disc_grads = d_tape.gradient(disc_loss, disc_vars)
    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))
    return gen_loss, disc_loss
epochs = 6


def train(dataset, epochs):
    """Run ``epochs`` passes of train_step over ``dataset``.

    Each element of ``dataset`` must be an ``(inputs, targets)`` pair of
    batches. The progress bar shows the mean of the latest generator and
    discriminator losses.
    """
    for epoch_idx in range(epochs):
        bar = tqdm(dataset, desc=f'Epoch {epoch_idx + 1}/{epochs}', unit='batch')
        for inputs_batch, targets_batch in bar:
            gen_loss, disc_loss = train_step(inputs_batch, targets_batch)
            latest = {
                'Generator Loss ': np.mean(gen_loss.numpy()),
                'Discriminator Loss ': np.mean(disc_loss.numpy()),
            }
            bar.set_postfix(latest)
# Train the model. The concatenated quarters arrive in a fixed order, so
# shuffle (buffer-based) before batching; the loop itself does no shuffling.
train(motion_pipe.shuffle(1024).batch(256).prefetch(3), epochs)
# Output: Epoch 1/6: 100%|██████████| 539/539 [09:11<00:00, 1.02s/batch, Generator Loss =0.884, Discriminator Loss =0.27] Epoch 2/6: 100%|██████████| 539/539 [09:03<00:00, 1.01s/batch, Generator Loss =0.631, Discriminator Loss =0.219] Epoch 3/6: 100%|██████████| 539/539 [09:03<00:00, 1.01s/batch, Generator Loss =0.832, Discriminator Loss =0.262] Epoch 4/6: 100%|██████████| 539/539 [09:03<00:00, 1.01s/batch, Generator Loss =0.827, Discriminator Loss =0.189] Epoch 5/6: 100%|██████████| 539/539 [09:02<00:00, 1.01s/batch, Generator Loss =0.832, Discriminator Loss =0.184] Epoch 6/6: 100%|██████████| 539/539 [09:03<00:00, 1.01s/batch, Generator Loss =0.962, Discriminator Loss =0.168]
# Compare targets, motion-blurred inputs and generator outputs on validation images
indices = [17, 19, 27, 39]
row_labels = ["Target Images", "Motion Blur Images", "Deblurred, Artifacts"]
fig, axes = plt.subplots(3, 4, figsize = (9, 7), subplot_kw = {"xticks": [], "yticks": []})
for col, index in enumerate(indices):
    # Fetch the (motion-blurred, target) pair once and reuse it for all rows
    blurred, target = (t.numpy() for t in next(iter(motion_validation.skip(index).take(1))))
    reconstructed = generator.predict(np.expand_dims(blurred, axis = 0))[0]
    for row, image in enumerate((target, blurred, reconstructed)):
        ax = axes[row, col]
        if col == 0:
            ax.text(-0.1, 0.5, row_labels[row], fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
        # Clip explicitly: generator outputs are not guaranteed to lie in [0, 1]
        ax.imshow((np.clip(image, 0, 1) * 255).astype(int))
plt.tight_layout()
# Output: 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step